package org.example.table;

import cn.hutool.core.date.DateUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.example.domin.Result;
import org.example.utils.JDBCUtils;

import java.sql.Connection;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import static org.example.concost.Concost.*;
import static org.example.utils.TimeConversion.convertUToTime;

/**
 * Reads meter curve data (voltage / current / power-factor readings) from Hive
 * partitions and pages it into a MySQL table through JDBC upserts.
 *
 * @author lwc
 * @date 2023/12/21 20:41
 */
@Slf4j
public class HiveCurveData {

    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("Spark HBase Example");
        SparkSession spark = null;
        if (args.length >= 1) {
            spark = SparkSession.builder()
                    .appName("HBase to MySQL Sync")
                    .config(sparkConf)
//                    .config("spark.master", "local[*]")  // 使用本地模式，[*]表示使用所有可用的核心
                    .enableHiveSupport()
                    .getOrCreate();
            log.info("加上了local[*]");
        } else {
            spark = SparkSession.builder()
                    .appName("HBase to MySQL Sync")
                    .config(sparkConf)
                    .config("spark.master", "local[*]")  // 使用本地模式，[*]表示使用所有可用的核心
                    .enableHiveSupport()
                    .getOrCreate();
        }
        Data(spark, sparkConf);
        spark.stop();
    }


    public static void Data(SparkSession spark, SparkConf sparkConf) {
        log.info("=================开始查询E_MP_U_CURVE数据===================");
        Dataset<Row> databases = spark.sql("show databases");
        databases.show();
        String database = sparkConf.get("spark.app.database", "ods_amr20_hbase");
        Dataset<Row> use = spark.sql("use " + database);
        String startDate = DateUtil.endOfMonth(DateUtil.date()).toString("yyyyMMdd");
        // 获取三个月后的日期
        String endDate = DateUtil.offsetMonth(DateUtil.parse(DateUtil.beginOfMonth(DateUtil.date()).toDateStr()), -3).toString("yyyyMMdd");
        String table = sparkConf.get("spark.app.table", "e_mp_pf_curve");
        String startTime = sparkConf.get("spark.app.startTime", endDate);
        String endTime = sparkConf.get("spark.app.endTime", startDate);
        log.info("输入的库名称为：{},表名称为：{}", database, table);
        try {
            String sql = String.format("select count(*) as count from ods_amr20_hbase.%s  where ds between '%s' " +
                    "and '%s'", table, startTime, endTime);
            log.info("查询:{}表count的语句为：{}", table, sql);
            Dataset<Row> tableCount = spark.sql(sql);
            tableCount.show();
            Optional<Row> count = tableCount.collectAsList().stream().findFirst();
            count.ifPresent(row -> whileInsert(row.getLong(0), spark, sparkConf, table, startTime, endTime));
        } catch (Exception e) {
            log.error("e_mp_u_curve报错:{}", e.getMessage());
            e.printStackTrace();
        }

        log.info("=================开始查询E_MP_U_CURVE 结束===================");
    }

    /**
     * 循环插入数据，直到数据插入完毕
     *
     * @param count
     * @param spark
     * @param sparkConf
     * @param table
     * @param startDate
     * @param endDate
     */
    private static void whileInsert(long count, SparkSession spark, SparkConf sparkConf, String table, String startDate, String endDate) {
        String url = sparkConf.get("spark.app.url", "jdbc:mysql://localhost:3306/test");
//        String username = sparkConf.get("spark.app.username", "sjzt_dws_yypdtqznjk_w");
//        String password = sparkConf.get("spark.app.password", "YyPdtqznjk_2023#$");
//        String dwsTable = sparkConf.get("spark.app.dwsTable", table);
        String username = sparkConf.get("spark.app.username", "root");
        String password = sparkConf.get("spark.app.password", "9XME3z94xs9nhCj");
        String dwsTable = sparkConf.get("spark.app.dwsTable", "t_obj_202312ld_hour");
        String type = sparkConf.get("spark.app.type", "10128E15");
        log.info("接收到url的路径为：{},username为：{},password为：{},dwsTable为：{},tpye为：{}", url, username, password, dwsTable, type);
        log.info("总共的数据为：{}", count);
        long pageNum = 0;
        Long pageSize = sparkConf.getLong("spark.app.pageSize", 1000);
        Long page = pageSize;
        log.info("接收到的pageSize为：{}", pageSize);
        while (true) {
            // 使用窗口函数进行查询
            String sql = String.format("SELECT * FROM (SELECT *, ROW_NUMBER() OVER (ORDER BY ds desc) as row_num FROM ods_amr20_hbase.%s" +
                    " where ds between '%s' and '%s' ) tmp where row_num >= %d and row_num <= %d", table, startDate, endDate, pageNum, pageSize);
            log.info("分页查询的语句为：{}", sql);
            Dataset<Row> sqlData = spark.sql(sql);
            long pageCount = sqlData.count();
            log.info("分页查询的数量：{}", pageCount);
            if (pageCount == 0) {
                //结束循环
                break;
            }

            sqlData.foreach(data -> {
                Connection connection = JDBCUtils.getConnection(url, username, password);
                try {
                    String sqlDataBuild = String.format("insert into %s (f_data_collection_time,f_data_input_time,f_measurement_points,f_key_name,f_delete", dwsTable);
                    Result result = tableGetResult(data, table, sqlDataBuild);
                    if (result != null) sqlDataBuild = result.getSql();
                    if (ObjectUtil.isNotNull(result)) {
                        String values = joinValues(data, result);
                        if (sqlDataBuild != null && StrUtil.isNotBlank(values)) {
                            sqlDataBuild = sqlDataBuild.concat(values);
                            if (result.getPhaseFlag() != null) {
                                sqlDataBuild = sqlAppendUpdate(sqlDataBuild, result.getPhaseFlag());
                            }
//                            log.info("插入的语句为：{}", sqlDataBuild);
                            int ai = JDBCUtils.executeUpdate(connection, sqlDataBuild);
                            log.info("插入的情况为：{}", ai);
                            sqlDataBuild = String.format("insert into %s (f_data_collection_time,f_data_input_time,f_measurement_points,f_key_name,f_delete", dwsTable);
                        } else {
                            log.info("table为：{},phase_flag为：{}values为空：{}", table, getRowObject(data, "phase_flag"), values);
                        }
                    }
                } catch (Exception e) {
                    log.error("循环出错：{}", e.getMessage());
                    e.printStackTrace();
                } finally {
                    connection.close();
                }
            });


            // 定义转换函数
//            RowConverter converter = new RowConverter();
//            JavaRDD<Row> javaRDD = sqlData.javaRDD().flatMap(row ->converter.transformRow(row).iterator());
////            Dataset<Row> processedRDD = spark.createDataFrame(javaRDD, schema);
//            // 创建 DataFrame
//            SparkSession newSpark = SparkSession.builder().config(sparkConf).getOrCreate();
//
//            Dataset<Row> processedDF = newSpark.createDataFrame(javaRDD, getStructType());
//
//            // 选择需要的字段
//            String[] selectedColumns = {"f_data_collection_time", "f_data_input_time", "f_measurement_points", "data_va", "data_vb", "data_vc", "f_delete"};
//            processedDF = processedDF.selectExpr(selectedColumns);
//            processedDF.show();
//
//            processedDF.write()
//                    .format("jdbc")
//                    .option("url", url)
//                    .option("dbtable", dwsTable)
//                    .option("user", username)
//                    .option("password", password)
//                    .mode(SaveMode.Append)
//                    .save();
            pageNum = pageSize;
            pageSize = pageSize + page;
            log.info("pageNum,pageSize");
        }

    }

    private static String joinValues(Row data, Result result) {
        String sqlDataBuild = "";
        Object point = getRowObject(data, "point");
        Object dataDate = getRowObject(data, "data_date");
        for (int i = 0; i < data.length(); i++) {
            String columnName = result.getType() + String.format("%04d", i);
            // 先用phaseFlag 进行判断，是那个数据
            Object dataDouble = getRowObject(data, columnName);
            if (ObjectUtil.isNotNull(dataDouble)) {
                sqlDataBuild = sqlDataBuild.concat(String.format(" ('%s' ,'%s' ,%d ,'%s',0,%f ),", dataDate + " " + convertUToTime(columnName), dataDate + " "
                        + convertUToTime(columnName), Long.parseLong(point.toString()), result.getKeyName(), Double.parseDouble(dataDouble.toString())));
            }
        }
        return sqlDataBuild;
    }

    private static Result tableGetResult(Row data, String table, String sqlDataBuild) {
        // 在这里需要重新获取 表判断的字段
        String phaseFlag = (String) getRowObject(data, "phase_flag");
        if (table.equals(EMPUCURVE)) {
            return new Result(phaseFlag, E215, "u", joinPirxd(sqlDataBuild, phaseFlag));
        } else if (table.equals(EMPPFCURVE)) {
            return new Result(phaseFlag, E89, "c", hourSqlJoin(sqlDataBuild, phaseFlag));
        } else if (table.equals(EMPICURVE)) {
//            joinPirxd(sqlDataBuild, phaseFlag);
            return new Result(phaseFlag, E25, "i", joinPirxd(sqlDataBuild, phaseFlag));
        } else {
            return null;
        }

    }

    private static String hourSqlJoin(String sqlDataBuild, String phaseFlag) {
        if (phaseFlag != null) {
            switch (phaseFlag) {
                case S10128E25:
                    sqlDataBuild = sqlDataBuild.concat(",data_vs) VALUES ");
                    break;
                case A10128E25:
                    sqlDataBuild = sqlDataBuild.concat(",data_va) VALUES ");
                    break;
                case B10128E25:
                    sqlDataBuild = sqlDataBuild.concat(",data_vb) VALUES ");
                    break;
                case C10128E25:
                    sqlDataBuild = sqlDataBuild.concat(",data_vc) VALUES ");
                    break;
            }
        }
        return sqlDataBuild;
    }

    private static String joinPirxd(String sqlDataBuild, String phaseFlag) {
        if (phaseFlag != null) {
            switch (phaseFlag) {
                case A10128E25:
                    return sqlDataBuild.concat(",data_va) VALUES ");
                case B10128E25:
                    return sqlDataBuild.concat(",data_vb) VALUES ");
                case C10128E25:
                    return sqlDataBuild.concat(",data_vc) VALUES ");
            }
        }
        return null;
    }

    private static String sqlAppendUpdate(String dataVaSql, String phaseFlag) {
        String substring = dataVaSql.substring(0, dataVaSql.lastIndexOf(","));
        switch (phaseFlag) {
            case S10128E25:
                return substring = substring.concat(" ON DUPLICATE KEY UPDATE f_measurement_points = VALUES(f_measurement_points)," +
                        "f_data_collection_time = values(f_data_collection_time),f_data_input_time = values(f_data_input_time)," +
                        "f_key_name = values(f_key_name),data_vs = values (data_vs),f_delete = values (f_delete)");
            case A10128E25:
                return substring = substring.concat(" ON DUPLICATE KEY UPDATE f_measurement_points = VALUES(f_measurement_points)," +
                        "f_data_collection_time = values(f_data_collection_time),f_data_input_time = values(f_data_input_time)," +
                        "f_key_name = values(f_key_name),data_va = values (data_va),f_delete = values (f_delete)");
            case B10128E25:

                return substring = substring.concat(" ON DUPLICATE KEY UPDATE f_measurement_points = VALUES(f_measurement_points)," +
                        "f_data_collection_time = values(f_data_collection_time),f_data_input_time = values(f_data_input_time)," +
                        "f_key_name = values(f_key_name),data_vb = values(data_vb),f_delete = values (f_delete)");
            case C10128E25:
                return substring = substring.concat(" ON DUPLICATE KEY UPDATE f_measurement_points = VALUES(f_measurement_points)," +
                        "f_data_collection_time = values(f_data_collection_time),f_data_input_time = values(f_data_input_time)," +
                        "f_key_name = values(f_key_name),data_vc = values (data_vc),f_delete = values (f_delete)");
            default:
                return null;
        }
    }

    /**
     * 拼接sql
     *
     * @param data
     * @param result
     * @param dataVaSql
     */
    private static String getsqlData(Row data, Result result, String dataVaSql) {

        dataVaSql = hourSqlJoin(dataVaSql, result.getPhaseFlag());

        if (result.getPhaseFlag() != null) {
            return dataVaSql;
        } else {
            return null;
        }

    }

    /**
     * 取出数据
     *
     * @param data
     * @param point
     * @return
     */
    private static Object getRowObject(Row data, String point) {
        if (!data.schema().getFieldIndex(point).toList().isEmpty() &&
                ObjectUtil.isNotNull(data.getAs(point))) {
            return data.getAs(point);
        } else {
            return null;
        }
    }

    private static StructType getStructType() {
        // 定义 MySQL 中的表结构
        StructType schema = DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("f_data_collection_time", DataTypes.StringType, true),
                DataTypes.createStructField("f_data_input_time", DataTypes.StringType, true),
                DataTypes.createStructField("f_measurement_points", DataTypes.StringType, true),
                DataTypes.createStructField("data_va", DataTypes.DoubleType, true),
                DataTypes.createStructField("data_vb", DataTypes.DoubleType, true),
                DataTypes.createStructField("data_vc", DataTypes.DoubleType, true),
                DataTypes.createStructField("f_delete", DataTypes.StringType, true)
        });
        return schema;
    }

    // 定义转换函数类
    static class RowConverter implements java.io.Serializable {
        public List<Row> transformRow(Row row) {
            List<Row> resultRows = new ArrayList<>();
            // 定义你的转换逻辑，根据需要修改
            String type = row.getString(row.fieldIndex("phase_flag"));
            if (A10128E25.equals(type)) {
                setRddList(row, resultRows, A10128E25);
            } else if (B10128E25.equals(type)) {
                setRddList(row, resultRows, B10128E25);
            } else if (C10128E25.equals(type)) {
                setRddList(row, resultRows, C10128E25);
            }
            return resultRows;
        }
    }

    private static void setRddList(Row row, List<Row> resultRows, String a10128e25) {
        Double dataVa = null;
        for (int i = 0; i < row.length(); i++) {
            String columnName = "u" + String.format("%04d", i);
            int size = row.schema().getFieldIndex(columnName).toList().size();
            if (size > 0) {
                String datav = row.getString(row.fieldIndex(columnName));
                if (ObjectUtil.isNotNull(datav)) {
                    dataVa = Double.parseDouble(datav);
                    if (a10128e25.equals("1")) {
                        resultRows.add(RowFactory.create(
                                row.getString(row.fieldIndex("data_date")) + " " + convertUToTime(columnName),
                                row.getString(row.fieldIndex("data_date")) + " " + convertUToTime(columnName),
                                row.getString(row.fieldIndex("point")),
                                dataVa,
                                0.0,
                                0.0,
                                "0"
                        ));

                    } else if (a10128e25.equals("2")) {
                        resultRows.add(RowFactory.create(
                                row.getString(row.fieldIndex("push_date")) + " " + convertUToTime(columnName),
                                row.getString(row.fieldIndex("push_date")) + " " + convertUToTime(columnName),
                                row.getString(row.fieldIndex("point")),
                                0.0,
                                dataVa,
                                0.0,
                                "0"
                        ));

                    } else if (a10128e25.equals("3")) {
                        resultRows.add(RowFactory.create(
                                row.getString(row.fieldIndex("push_date")) + " " + convertUToTime(columnName),
                                row.getString(row.fieldIndex("push_date")) + " " + convertUToTime(columnName),
                                row.getString(row.fieldIndex("point")),
                                0.0,
                                0.0,
                                dataVa,
                                "0"
                        ));

                    }
                }
            }
        }
    }
}
